{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import Scalar, draw_graph\n",
    "from linear_model import Linear, mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"357pt\" height=\"452pt\"\n",
       " viewBox=\"0.00 0.00 356.79 451.61\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 447.609)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-447.609 352.792,-447.609 352.792,4 -4,4\"/>\n",
       "<!-- 140621852371504forward -->\n",
       "<g id=\"node1\" class=\"node\"><title>140621852371504forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M14,-192.344C14,-192.344 66.2051,-192.344 66.2051,-192.344 72.2051,-192.344 78.2051,-198.344 78.2051,-204.344 78.2051,-204.344 78.2051,-239.266 78.2051,-239.266 78.2051,-245.266 72.2051,-251.266 66.2051,-251.266 66.2051,-251.266 14,-251.266 14,-251.266 8,-251.266 2,-245.266 2,-239.266 2,-239.266 2,-204.344 2,-204.344 2,-198.344 8,-192.344 14,-192.344\"/>\n",
       "<text text-anchor=\"middle\" x=\"40.1025\" y=\"-239.266\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"2,-231.625 78.2051,-231.625 \"/>\n",
       "<text text-anchor=\"middle\" x=\"40.1025\" y=\"-219.625\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"2,-211.984 78.2051,-211.984 \"/>\n",
       "<text text-anchor=\"middle\" x=\"40.1025\" y=\"-199.984\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140621852371600forward -->\n",
       "<g id=\"node3\" class=\"node\"><title>140621852371600forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M104,-288.266C104,-288.266 156.205,-288.266 156.205,-288.266 162.205,-288.266 168.205,-294.266 168.205,-300.266 168.205,-300.266 168.205,-335.188 168.205,-335.188 168.205,-341.188 162.205,-347.188 156.205,-347.188 156.205,-347.188 104,-347.188 104,-347.188 98,-347.188 92,-341.188 92,-335.188 92,-335.188 92,-300.266 92,-300.266 92,-294.266 98,-288.266 104,-288.266\"/>\n",
       "<text text-anchor=\"middle\" x=\"130.103\" y=\"-335.188\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"92,-327.547 168.205,-327.547 \"/>\n",
       "<text text-anchor=\"middle\" x=\"130.103\" y=\"-315.547\" font-family=\"Menlo\" font-size=\"10.00\">value=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"92,-307.906 168.205,-307.906 \"/>\n",
       "<text text-anchor=\"middle\" x=\"130.103\" y=\"-295.906\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140621852371504forward&#45;&gt;140621852371600forward -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>140621852371504forward&#45;&gt;140621852371600forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M67.5945,-251.495C76.4924,-260.78 86.4722,-271.195 95.8008,-280.93\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"93.3331,-283.414 102.779,-288.212 98.3872,-278.57 93.3331,-283.414\"/>\n",
       "</g>\n",
       "<!-- 140621852371072forward -->\n",
       "<g id=\"node2\" class=\"node\"><title>140621852371072forward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"38.1025\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x1=1.50</text>\n",
       "</g>\n",
       "<!-- 140621852371120forward -->\n",
       "<g id=\"node4\" class=\"node\"><title>140621852371120forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M12,-96.4219C12,-96.4219 64.2051,-96.4219 64.2051,-96.4219 70.2051,-96.4219 76.2051,-102.422 76.2051,-108.422 76.2051,-108.422 76.2051,-143.344 76.2051,-143.344 76.2051,-149.344 70.2051,-155.344 64.2051,-155.344 64.2051,-155.344 12,-155.344 12,-155.344 6,-155.344 0,-149.344 0,-143.344 0,-143.344 0,-108.422 0,-108.422 0,-102.422 6,-96.4219 12,-96.4219\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-143.344\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-135.703 76.2051,-135.703 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-123.703\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-116.062 76.2051,-116.062 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-104.062\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140621852371072forward&#45;&gt;140621852371120forward -->\n",
       "<g id=\"edge11\" class=\"edge\"><title>140621852371072forward&#45;&gt;140621852371120forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M38.1025,-48.1877C38.1025,-58.8016 38.1025,-72.8581 38.1025,-86.0355\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"34.6026,-86.3326 38.1025,-96.3326 41.6026,-86.3326 34.6026,-86.3326\"/>\n",
       "</g>\n",
       "<!-- 140621852371792forward -->\n",
       "<g id=\"node10\" class=\"node\"><title>140621852371792forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M151,-384.188C151,-384.188 203.205,-384.188 203.205,-384.188 209.205,-384.188 215.205,-390.188 215.205,-396.188 215.205,-396.188 215.205,-431.109 215.205,-431.109 215.205,-437.109 209.205,-443.109 203.205,-443.109 203.205,-443.109 151,-443.109 151,-443.109 145,-443.109 139,-437.109 139,-431.109 139,-431.109 139,-396.188 139,-396.188 139,-390.188 145,-384.188 151,-384.188\"/>\n",
       "<text text-anchor=\"middle\" x=\"177.103\" y=\"-431.109\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"139,-423.469 215.205,-423.469 \"/>\n",
       "<text text-anchor=\"middle\" x=\"177.103\" y=\"-411.469\" font-family=\"Menlo\" font-size=\"10.00\">value=8.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"139,-403.828 215.205,-403.828 \"/>\n",
       "<text text-anchor=\"middle\" x=\"177.103\" y=\"-391.828\" font-family=\"Menlo\" font-size=\"10.00\">mse</text>\n",
       "</g>\n",
       "<!-- 140621852371600forward&#45;&gt;140621852371792forward -->\n",
       "<g id=\"edge13\" class=\"edge\"><title>140621852371600forward&#45;&gt;140621852371792forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M144.459,-347.417C148.833,-356.156 153.707,-365.895 158.326,-375.127\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"155.228,-376.758 162.834,-384.134 161.488,-373.625 155.228,-376.758\"/>\n",
       "</g>\n",
       "<!-- 140621852371120forward&#45;&gt;140621852371504forward -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>140621852371120forward&#45;&gt;140621852371504forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M38.7135,-155.573C38.8918,-163.948 39.0897,-173.242 39.2789,-182.127\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"35.7831,-182.367 39.4953,-192.29 42.7816,-182.218 35.7831,-182.367\"/>\n",
       "</g>\n",
       "<!-- 140621852371648forward -->\n",
       "<g id=\"node5\" class=\"node\"><title>140621852371648forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M200,-96.4219C200,-96.4219 252.205,-96.4219 252.205,-96.4219 258.205,-96.4219 264.205,-102.422 264.205,-108.422 264.205,-108.422 264.205,-143.344 264.205,-143.344 264.205,-149.344 258.205,-155.344 252.205,-155.344 252.205,-155.344 200,-155.344 200,-155.344 194,-155.344 188,-149.344 188,-143.344 188,-143.344 188,-108.422 188,-108.422 188,-102.422 194,-96.4219 200,-96.4219\"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-143.344\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"188,-135.703 264.205,-135.703 \"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-123.703\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"188,-116.062 264.205,-116.062 \"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-104.062\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140621852371696forward -->\n",
       "<g id=\"node7\" class=\"node\"><title>140621852371696forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M197,-192.344C197,-192.344 249.205,-192.344 249.205,-192.344 255.205,-192.344 261.205,-198.344 261.205,-204.344 261.205,-204.344 261.205,-239.266 261.205,-239.266 261.205,-245.266 255.205,-251.266 249.205,-251.266 249.205,-251.266 197,-251.266 197,-251.266 191,-251.266 185,-245.266 185,-239.266 185,-239.266 185,-204.344 185,-204.344 185,-198.344 191,-192.344 197,-192.344\"/>\n",
       "<text text-anchor=\"middle\" x=\"223.103\" y=\"-239.266\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"185,-231.625 261.205,-231.625 \"/>\n",
       "<text text-anchor=\"middle\" x=\"223.103\" y=\"-219.625\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"185,-211.984 261.205,-211.984 \"/>\n",
       "<text text-anchor=\"middle\" x=\"223.103\" y=\"-199.984\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140621852371648forward&#45;&gt;140621852371696forward -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>140621852371648forward&#45;&gt;140621852371696forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M225.186,-155.573C224.919,-163.948 224.622,-173.242 224.338,-182.127\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"220.834,-182.184 224.013,-192.29 227.831,-182.407 220.834,-182.184\"/>\n",
       "</g>\n",
       "<!-- 140621852371168forward -->\n",
       "<g id=\"node6\" class=\"node\"><title>140621852371168forward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"94,-0.5 94,-59.4219 170.205,-59.4219 170.205,-0.5 94,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-47.4219\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"94,-39.7812 170.205,-39.7812 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"94,-20.1406 170.205,-20.1406 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-8.14062\" font-family=\"Menlo\" font-size=\"10.00\">a</text>\n",
       "</g>\n",
       "<!-- 140621852371168forward&#45;&gt;140621852371120forward -->\n",
       "<g id=\"edge12\" class=\"edge\"><title>140621852371168forward&#45;&gt;140621852371120forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M103.389,-59.6509C94.0954,-68.9365 83.672,-79.3512 73.9288,-89.0864\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"71.2407,-86.8244 66.6406,-96.3685 76.1885,-91.7763 71.2407,-86.8244\"/>\n",
       "</g>\n",
       "<!-- 140621852371168forward&#45;&gt;140621852371648forward -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>140621852371168forward&#45;&gt;140621852371648forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M160.816,-59.6509C170.11,-68.9365 180.533,-79.3512 190.276,-89.0864\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"188.017,-91.7763 197.565,-96.3685 192.964,-86.8244 188.017,-91.7763\"/>\n",
       "</g>\n",
       "<!-- 140621852371744forward -->\n",
       "<g id=\"node8\" class=\"node\"><title>140621852371744forward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M198,-288.266C198,-288.266 250.205,-288.266 250.205,-288.266 256.205,-288.266 262.205,-294.266 262.205,-300.266 262.205,-300.266 262.205,-335.188 262.205,-335.188 262.205,-341.188 256.205,-347.188 250.205,-347.188 250.205,-347.188 198,-347.188 198,-347.188 192,-347.188 186,-341.188 186,-335.188 186,-335.188 186,-300.266 186,-300.266 186,-294.266 192,-288.266 198,-288.266\"/>\n",
       "<text text-anchor=\"middle\" x=\"224.103\" y=\"-335.188\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"186,-327.547 262.205,-327.547 \"/>\n",
       "<text text-anchor=\"middle\" x=\"224.103\" y=\"-315.547\" font-family=\"Menlo\" font-size=\"10.00\">value=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"186,-307.906 262.205,-307.906 \"/>\n",
       "<text text-anchor=\"middle\" x=\"224.103\" y=\"-295.906\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140621852371696forward&#45;&gt;140621852371744forward -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>140621852371696forward&#45;&gt;140621852371744forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M223.408,-251.495C223.497,-259.87 223.596,-269.164 223.691,-278.049\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"220.193,-278.25 223.799,-288.212 227.192,-278.176 220.193,-278.25\"/>\n",
       "</g>\n",
       "<!-- 140621852371744forward&#45;&gt;140621852371792forward -->\n",
       "<g id=\"edge14\" class=\"edge\"><title>140621852371744forward&#45;&gt;140621852371792forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M209.746,-347.417C205.372,-356.156 200.498,-365.895 195.879,-375.127\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"192.717,-373.625 191.372,-384.134 198.977,-376.758 192.717,-373.625\"/>\n",
       "</g>\n",
       "<!-- 140621852371264forward -->\n",
       "<g id=\"node9\" class=\"node\"><title>140621852371264forward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"94,-96.4219 94,-155.344 170.205,-155.344 170.205,-96.4219 94,-96.4219\"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-143.344\" font-family=\"Menlo\" font-size=\"10.00\">grad=None</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"94,-135.703 170.205,-135.703 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-123.703\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"94,-116.062 170.205,-116.062 \"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-104.062\" font-family=\"Menlo\" font-size=\"10.00\">b</text>\n",
       "</g>\n",
       "<!-- 140621852371264forward&#45;&gt;140621852371504forward -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>140621852371264forward&#45;&gt;140621852371504forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M104,-155.573C94.904,-164.858 84.7025,-175.273 75.1665,-185.008\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"72.5307,-182.697 68.0334,-192.29 77.5313,-187.596 72.5307,-182.697\"/>\n",
       "</g>\n",
       "<!-- 140621852371264forward&#45;&gt;140621852371696forward -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>140621852371264forward&#45;&gt;140621852371696forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M159.9,-155.573C168.897,-164.858 178.987,-175.273 188.42,-185.008\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"186.003,-187.544 195.475,-192.29 191.03,-182.673 186.003,-187.544\"/>\n",
       "</g>\n",
       "<!-- 140621852371312forward -->\n",
       "<g id=\"node11\" class=\"node\"><title>140621852371312forward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"226.103\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x2=2.00</text>\n",
       "</g>\n",
       "<!-- 140621852371312forward&#45;&gt;140621852371648forward -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>140621852371312forward&#45;&gt;140621852371648forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M226.103,-48.1877C226.103,-58.8016 226.103,-72.8581 226.103,-86.0355\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"222.603,-86.3326 226.103,-96.3326 229.603,-86.3326 222.603,-86.3326\"/>\n",
       "</g>\n",
       "<!-- 140621852371360forward -->\n",
       "<g id=\"node12\" class=\"node\"><title>140621852371360forward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"131.103\" cy=\"-221.805\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"131.103\" y=\"-219.625\" font-family=\"Menlo\" font-size=\"10.00\">y1=1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371360forward&#45;&gt;140621852371600forward -->\n",
       "<g id=\"edge10\" class=\"edge\"><title>140621852371360forward&#45;&gt;140621852371600forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M130.919,-240.031C130.806,-250.645 130.656,-264.702 130.516,-277.879\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"127.013,-278.14 130.407,-288.176 134.013,-278.214 127.013,-278.14\"/>\n",
       "</g>\n",
       "<!-- 140621852371408forward -->\n",
       "<g id=\"node13\" class=\"node\"><title>140621852371408forward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"314.103\" cy=\"-221.805\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"314.103\" y=\"-219.625\" font-family=\"Menlo\" font-size=\"10.00\">y2=4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371408forward&#45;&gt;140621852371744forward -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>140621852371408forward&#45;&gt;140621852371744forward</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M299.241,-238.314C288.272,-249.761 272.881,-265.822 258.883,-280.43\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"256.113,-278.263 251.721,-287.905 261.167,-283.106 256.113,-278.263\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe5138eec10>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Computation-graph growth: every training sample adds its own copy of the\n",
    "# model's sub-graph (a*x, +b, difference with y), all feeding a single mse node.\n",
    "model = Linear()\n",
    "\n",
    "# Training data: constants that do not require gradients\n",
    "x1 = Scalar(1.5, label='x1', requires_grad=False)\n",
    "y1 = Scalar(1.0, label='y1', requires_grad=False)\n",
    "x2 = Scalar(2.0, label='x2', requires_grad=False)\n",
    "y2 = Scalar(4.0, label='y2', requires_grad=False)\n",
    "\n",
    "loss = mse([model.error(x, y) for x, y in ((x1, y1), (x2, y2))])\n",
    "draw_graph(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"357pt\" height=\"517pt\"\n",
       " viewBox=\"0.00 0.00 356.79 516.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 512.797)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-512.797 352.792,-512.797 352.792,4 -4,4\"/>\n",
       "<!-- 140621852371504backward -->\n",
       "<g id=\"node1\" class=\"node\"><title>140621852371504backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M15,-224.938C15,-224.938 67.2051,-224.938 67.2051,-224.938 73.2051,-224.938 79.2051,-230.938 79.2051,-236.938 79.2051,-236.938 79.2051,-271.859 79.2051,-271.859 79.2051,-277.859 73.2051,-283.859 67.2051,-283.859 67.2051,-283.859 15,-283.859 15,-283.859 9,-283.859 3,-277.859 3,-271.859 3,-271.859 3,-236.938 3,-236.938 3,-230.938 9,-224.938 15,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"41.1025\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"3,-264.219 79.2051,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"41.1025\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"3,-244.578 79.2051,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"41.1025\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140621852371120backward -->\n",
       "<g id=\"node4\" class=\"node\"><title>140621852371120backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M12,-112.719C12,-112.719 64.2051,-112.719 64.2051,-112.719 70.2051,-112.719 76.2051,-118.719 76.2051,-124.719 76.2051,-124.719 76.2051,-159.641 76.2051,-159.641 76.2051,-165.641 70.2051,-171.641 64.2051,-171.641 64.2051,-171.641 12,-171.641 12,-171.641 6,-171.641 0,-165.641 0,-159.641 0,-159.641 0,-124.719 0,-124.719 0,-118.719 6,-112.719 12,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-152 76.2051,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-132.359 76.2051,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140621852371504backward&#45;&gt;140621852371120backward -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>140621852371504backward&#45;&gt;140621852371120backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M37.4497,-224.63C36.8287,-218.628 36.2809,-212.334 35.959,-206.438 35.5251,-198.491 35.5083,-189.978 35.7072,-181.855\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"39.2055,-181.964 36.0595,-171.847 32.2098,-181.717 39.2055,-181.964\"/>\n",
       "<text text-anchor=\"middle\" x=\"57.1743\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371264backward -->\n",
       "<g id=\"node9\" class=\"node\"><title>140621852371264backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"96,-112.719 96,-171.641 172.205,-171.641 172.205,-112.719 96,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;5.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"96,-152 172.205,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"96,-132.359 172.205,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">b</text>\n",
       "</g>\n",
       "<!-- 140621852371504backward&#45;&gt;140621852371264backward -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>140621852371504backward&#45;&gt;140621852371264backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M65.2944,-224.727C76.9194,-210.95 90.9822,-194.284 103.357,-179.618\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"106.258,-181.607 110.032,-171.707 100.908,-177.093 106.258,-181.607\"/>\n",
       "<text text-anchor=\"middle\" x=\"115.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371072backward -->\n",
       "<g id=\"node2\" class=\"node\"><title>140621852371072backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"38.1025\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x1=1.50</text>\n",
       "</g>\n",
       "<!-- 140621852371600backward -->\n",
       "<g id=\"node3\" class=\"node\"><title>140621852371600backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M105,-337.156C105,-337.156 157.205,-337.156 157.205,-337.156 163.205,-337.156 169.205,-343.156 169.205,-349.156 169.205,-349.156 169.205,-384.078 169.205,-384.078 169.205,-390.078 163.205,-396.078 157.205,-396.078 157.205,-396.078 105,-396.078 105,-396.078 99,-396.078 93,-390.078 93,-384.078 93,-384.078 93,-349.156 93,-349.156 93,-343.156 99,-337.156 105,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"131.103\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"93,-376.438 169.205,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"131.103\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"93,-356.797 169.205,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"131.103\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140621852371600backward&#45;&gt;140621852371504backward -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>140621852371600backward&#45;&gt;140621852371504backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M102.231,-336.929C96.6677,-331.03 91.0056,-324.755 85.959,-318.656 79.1297,-310.403 72.2123,-301.168 65.9212,-292.361\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"68.7212,-290.258 60.1048,-284.093 62.9959,-294.285 68.7212,-290.258\"/>\n",
       "<text text-anchor=\"middle\" x=\"107.174\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371360backward -->\n",
       "<g id=\"node12\" class=\"node\"><title>140621852371360backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"132.103\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"132.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y1=1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371600backward&#45;&gt;140621852371360backward -->\n",
       "<g id=\"edge10\" class=\"edge\"><title>140621852371600backward&#45;&gt;140621852371360backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M131.363,-336.946C131.548,-316.535 131.791,-289.784 131.947,-272.59\"/>\n",
       "</g>\n",
       "<!-- 140621852371120backward&#45;&gt;140621852371072backward -->\n",
       "<g id=\"edge11\" class=\"edge\"><title>140621852371120backward&#45;&gt;140621852371072backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M38.1025,-112.509C38.1025,-92.0979 38.1025,-65.346 38.1025,-48.1524\"/>\n",
       "</g>\n",
       "<!-- 140621852371168backward -->\n",
       "<g id=\"node6\" class=\"node\"><title>140621852371168backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"96,-0.5 96,-59.4219 172.205,-59.4219 172.205,-0.5 96,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-47.4219\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;9.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"96,-39.7812 172.205,-39.7812 \"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"96,-20.1406 172.205,-20.1406 \"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-8.14062\" font-family=\"Menlo\" font-size=\"10.00\">a</text>\n",
       "</g>\n",
       "<!-- 140621852371120backward&#45;&gt;140621852371168backward -->\n",
       "<g id=\"edge12\" class=\"edge\"><title>140621852371120backward&#45;&gt;140621852371168backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M63.0748,-112.509C75.0748,-98.7314 89.5912,-82.065 102.365,-67.3991\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"105.327,-69.3279 109.255,-59.4883 100.048,-64.7303 105.327,-69.3279\"/>\n",
       "<text text-anchor=\"middle\" x=\"113.174\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.50</text>\n",
       "</g>\n",
       "<!-- 140621852371648backward -->\n",
       "<g id=\"node5\" class=\"node\"><title>140621852371648backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M203,-112.719C203,-112.719 255.205,-112.719 255.205,-112.719 261.205,-112.719 267.205,-118.719 267.205,-124.719 267.205,-124.719 267.205,-159.641 267.205,-159.641 267.205,-165.641 261.205,-171.641 255.205,-171.641 255.205,-171.641 203,-171.641 203,-171.641 197,-171.641 191,-165.641 191,-159.641 191,-159.641 191,-124.719 191,-124.719 191,-118.719 197,-112.719 203,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"229.103\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"191,-152 267.205,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"229.103\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"191,-132.359 267.205,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"229.103\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140621852371648backward&#45;&gt;140621852371168backward -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>140621852371648backward&#45;&gt;140621852371168backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M199.994,-112.634C194.217,-106.671 188.293,-100.335 182.959,-94.2188 175.627,-85.811 168.083,-76.4458 161.181,-67.5536\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"163.913,-65.3644 155.049,-59.5612 158.36,-69.6255 163.913,-65.3644\"/>\n",
       "<text text-anchor=\"middle\" x=\"204.174\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;8.00</text>\n",
       "</g>\n",
       "<!-- 140621852371312backward -->\n",
       "<g id=\"node11\" class=\"node\"><title>140621852371312backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"229.103\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"229.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x2=2.00</text>\n",
       "</g>\n",
       "<!-- 140621852371648backward&#45;&gt;140621852371312backward -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>140621852371648backward&#45;&gt;140621852371312backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M229.103,-112.509C229.103,-92.0979 229.103,-65.346 229.103,-48.1524\"/>\n",
       "</g>\n",
       "<!-- 140621852371696backward -->\n",
       "<g id=\"node7\" class=\"node\"><title>140621852371696backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M197,-224.938C197,-224.938 249.205,-224.938 249.205,-224.938 255.205,-224.938 261.205,-230.938 261.205,-236.938 261.205,-236.938 261.205,-271.859 261.205,-271.859 261.205,-277.859 255.205,-283.859 249.205,-283.859 249.205,-283.859 197,-283.859 197,-283.859 191,-283.859 185,-277.859 185,-271.859 185,-271.859 185,-236.938 185,-236.938 185,-230.938 191,-224.938 197,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"223.103\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"185,-264.219 261.205,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"223.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"185,-244.578 261.205,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"223.103\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140621852371696backward&#45;&gt;140621852371648backward -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>140621852371696backward&#45;&gt;140621852371648backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M224.663,-224.727C225.372,-211.716 226.22,-196.127 226.985,-182.08\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"230.501,-181.883 227.55,-171.707 223.511,-181.502 230.501,-181.883\"/>\n",
       "<text text-anchor=\"middle\" x=\"248.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371696backward&#45;&gt;140621852371264backward -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>140621852371696backward&#45;&gt;140621852371264backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M195.867,-224.806C190.467,-218.847 184.933,-212.524 179.959,-206.438 173.129,-198.08 166.115,-188.807 159.691,-179.994\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"162.423,-177.799 153.737,-171.731 156.743,-181.891 162.423,-177.799\"/>\n",
       "<text text-anchor=\"middle\" x=\"201.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371744backward -->\n",
       "<g id=\"node8\" class=\"node\"><title>140621852371744backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M199,-337.156C199,-337.156 251.205,-337.156 251.205,-337.156 257.205,-337.156 263.205,-343.156 263.205,-349.156 263.205,-349.156 263.205,-384.078 263.205,-384.078 263.205,-390.078 257.205,-396.078 251.205,-396.078 251.205,-396.078 199,-396.078 199,-396.078 193,-396.078 187,-390.078 187,-384.078 187,-384.078 187,-349.156 187,-349.156 187,-343.156 193,-337.156 199,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"225.103\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"187,-376.438 263.205,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"225.103\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"187,-356.797 263.205,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"225.103\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140621852371744backward&#45;&gt;140621852371696backward -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>140621852371744backward&#45;&gt;140621852371696backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M224.582,-336.946C224.346,-323.934 224.063,-308.345 223.808,-294.299\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"227.301,-293.861 223.62,-283.926 220.302,-293.988 227.301,-293.861\"/>\n",
       "<text text-anchor=\"middle\" x=\"246.174\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371408backward -->\n",
       "<g id=\"node13\" class=\"node\"><title>140621852371408backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"314.103\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"314.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y2=4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371744backward&#45;&gt;140621852371408backward -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>140621852371744backward&#45;&gt;140621852371408backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M254.651,-337.059C260.309,-331.15 266.042,-324.838 271.103,-318.656 283.468,-303.552 295.811,-284.936 304.08,-271.822\"/>\n",
       "</g>\n",
       "<!-- 140621852371792backward -->\n",
       "<g id=\"node10\" class=\"node\"><title>140621852371792backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M156,-449.375C156,-449.375 208.205,-449.375 208.205,-449.375 214.205,-449.375 220.205,-455.375 220.205,-461.375 220.205,-461.375 220.205,-496.297 220.205,-496.297 220.205,-502.297 214.205,-508.297 208.205,-508.297 208.205,-508.297 156,-508.297 156,-508.297 150,-508.297 144,-502.297 144,-496.297 144,-496.297 144,-461.375 144,-461.375 144,-455.375 150,-449.375 156,-449.375\"/>\n",
       "<text text-anchor=\"middle\" x=\"182.103\" y=\"-496.297\" font-family=\"Menlo\" font-size=\"10.00\">grad=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"144,-488.656 220.205,-488.656 \"/>\n",
       "<text text-anchor=\"middle\" x=\"182.103\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">value=8.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"144,-469.016 220.205,-469.016 \"/>\n",
       "<text text-anchor=\"middle\" x=\"182.103\" y=\"-457.016\" font-family=\"Menlo\" font-size=\"10.00\">mse</text>\n",
       "</g>\n",
       "<!-- 140621852371792backward&#45;&gt;140621852371600backward -->\n",
       "<g id=\"edge13\" class=\"edge\"><title>140621852371792backward&#45;&gt;140621852371600backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M168.836,-449.165C162.638,-435.77 155.177,-419.645 148.531,-405.282\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"151.679,-403.75 144.303,-396.145 145.326,-406.69 151.679,-403.75\"/>\n",
       "<text text-anchor=\"middle\" x=\"176.96\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371792backward&#45;&gt;140621852371744backward -->\n",
       "<g id=\"edge14\" class=\"edge\"><title>140621852371792backward&#45;&gt;140621852371744backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M193.288,-449.165C198.464,-435.898 204.685,-419.952 210.248,-405.693\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"213.599,-406.733 213.973,-396.145 207.078,-404.189 213.599,-406.733\"/>\n",
       "<text text-anchor=\"middle\" x=\"223.96\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">4.00</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe5138ee280>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 第一次触发反向传播\n",
    "loss.backward()\n",
    "draw_graph(loss, 'backward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
       " -->\n",
       "<!-- Title: %3 Pages: 1 -->\n",
       "<svg width=\"360pt\" height=\"517pt\"\n",
       " viewBox=\"0.00 0.00 359.79 516.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 512.797)\">\n",
       "<title>%3</title>\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-512.797 355.792,-512.797 355.792,4 -4,4\"/>\n",
       "<!-- 140621852371504backward -->\n",
       "<g id=\"node1\" class=\"node\"><title>140621852371504backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M18,-224.938C18,-224.938 70.2051,-224.938 70.2051,-224.938 76.2051,-224.938 82.2051,-230.938 82.2051,-236.938 82.2051,-236.938 82.2051,-271.859 82.2051,-271.859 82.2051,-277.859 76.2051,-283.859 70.2051,-283.859 70.2051,-283.859 18,-283.859 18,-283.859 12,-283.859 6,-277.859 6,-271.859 6,-271.859 6,-236.938 6,-236.938 6,-230.938 12,-224.938 18,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"44.1025\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"6,-264.219 82.2051,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"44.1025\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"6,-244.578 82.2051,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"44.1025\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140621852371120backward -->\n",
       "<g id=\"node4\" class=\"node\"><title>140621852371120backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M12,-112.719C12,-112.719 64.2051,-112.719 64.2051,-112.719 70.2051,-112.719 76.2051,-118.719 76.2051,-124.719 76.2051,-124.719 76.2051,-159.641 76.2051,-159.641 76.2051,-165.641 70.2051,-171.641 64.2051,-171.641 64.2051,-171.641 12,-171.641 12,-171.641 6,-171.641 0,-165.641 0,-159.641 0,-159.641 0,-124.719 0,-124.719 0,-118.719 6,-112.719 12,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-152 76.2051,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"0,-132.359 76.2051,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140621852371504backward&#45;&gt;140621852371120backward -->\n",
       "<g id=\"edge6\" class=\"edge\"><title>140621852371504backward&#45;&gt;140621852371120backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M40.5115,-224.627C39.886,-218.625 39.3208,-212.332 38.959,-206.438 38.4716,-198.498 38.1914,-189.988 38.0402,-181.866\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"41.5399,-181.812 37.9118,-171.857 34.5405,-181.901 41.5399,-181.812\"/>\n",
       "<text text-anchor=\"middle\" x=\"60.1743\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371264backward -->\n",
       "<g id=\"node9\" class=\"node\"><title>140621852371264backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"93.9897,-112.719 93.9897,-171.641 176.215,-171.641 176.215,-112.719 93.9897,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;10.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"93.9897,-152 176.215,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"93.9897,-132.359 176.215,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">b</text>\n",
       "</g>\n",
       "<!-- 140621852371504backward&#45;&gt;140621852371264backward -->\n",
       "<g id=\"edge3\" class=\"edge\"><title>140621852371504backward&#45;&gt;140621852371264backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M67.7742,-224.727C79.1491,-210.95 92.9095,-194.284 105.018,-179.618\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"107.882,-181.647 111.549,-171.707 102.484,-177.19 107.882,-181.647\"/>\n",
       "<text text-anchor=\"middle\" x=\"116.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371072backward -->\n",
       "<g id=\"node2\" class=\"node\"><title>140621852371072backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"38.1025\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"38.1025\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x1=1.50</text>\n",
       "</g>\n",
       "<!-- 140621852371600backward -->\n",
       "<g id=\"node3\" class=\"node\"><title>140621852371600backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M108,-337.156C108,-337.156 160.205,-337.156 160.205,-337.156 166.205,-337.156 172.205,-343.156 172.205,-349.156 172.205,-349.156 172.205,-384.078 172.205,-384.078 172.205,-390.078 166.205,-396.078 160.205,-396.078 160.205,-396.078 108,-396.078 108,-396.078 102,-396.078 96,-390.078 96,-384.078 96,-384.078 96,-349.156 96,-349.156 96,-343.156 102,-337.156 108,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"96,-376.438 172.205,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=1.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"96,-356.797 172.205,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"134.103\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140621852371600backward&#45;&gt;140621852371504backward -->\n",
       "<g id=\"edge5\" class=\"edge\"><title>140621852371600backward&#45;&gt;140621852371504backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M105.231,-336.929C99.6677,-331.03 94.0056,-324.755 88.959,-318.656 82.1297,-310.403 75.2123,-301.168 68.9212,-292.361\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"71.7212,-290.258 63.1048,-284.093 65.9959,-294.285 71.7212,-290.258\"/>\n",
       "<text text-anchor=\"middle\" x=\"110.174\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371360backward -->\n",
       "<g id=\"node12\" class=\"node\"><title>140621852371360backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"135.103\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y1=1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371600backward&#45;&gt;140621852371360backward -->\n",
       "<g id=\"edge10\" class=\"edge\"><title>140621852371600backward&#45;&gt;140621852371360backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M134.363,-336.946C134.548,-316.535 134.791,-289.784 134.947,-272.59\"/>\n",
       "</g>\n",
       "<!-- 140621852371120backward&#45;&gt;140621852371072backward -->\n",
       "<g id=\"edge11\" class=\"edge\"><title>140621852371120backward&#45;&gt;140621852371072backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M38.1025,-112.509C38.1025,-92.0979 38.1025,-65.346 38.1025,-48.1524\"/>\n",
       "</g>\n",
       "<!-- 140621852371168backward -->\n",
       "<g id=\"node6\" class=\"node\"><title>140621852371168backward</title>\n",
       "<polygon fill=\"lightgreen\" stroke=\"black\" stroke-width=\"2\" points=\"93.9897,-0.5 93.9897,-59.4219 176.215,-59.4219 176.215,-0.5 93.9897,-0.5\"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-47.4219\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;19.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"93.9897,-39.7812 176.215,-39.7812 \"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" stroke-width=\"2\" points=\"93.9897,-20.1406 176.215,-20.1406 \"/>\n",
       "<text text-anchor=\"middle\" x=\"135.103\" y=\"-8.14062\" font-family=\"Menlo\" font-size=\"10.00\">a</text>\n",
       "</g>\n",
       "<!-- 140621852371120backward&#45;&gt;140621852371168backward -->\n",
       "<g id=\"edge12\" class=\"edge\"><title>140621852371120backward&#45;&gt;140621852371168backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M63.3349,-112.509C75.4599,-98.7314 90.1275,-82.065 103.034,-67.3991\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"106.017,-69.3076 109.996,-59.4883 100.762,-64.6829 106.017,-69.3076\"/>\n",
       "<text text-anchor=\"middle\" x=\"114.174\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;1.50</text>\n",
       "</g>\n",
       "<!-- 140621852371648backward -->\n",
       "<g id=\"node5\" class=\"node\"><title>140621852371648backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M206,-112.719C206,-112.719 258.205,-112.719 258.205,-112.719 264.205,-112.719 270.205,-118.719 270.205,-124.719 270.205,-124.719 270.205,-159.641 270.205,-159.641 270.205,-165.641 264.205,-171.641 258.205,-171.641 258.205,-171.641 206,-171.641 206,-171.641 200,-171.641 194,-165.641 194,-159.641 194,-159.641 194,-124.719 194,-124.719 194,-118.719 200,-112.719 206,-112.719\"/>\n",
       "<text text-anchor=\"middle\" x=\"232.103\" y=\"-159.641\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;8.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"194,-152 270.205,-152 \"/>\n",
       "<text text-anchor=\"middle\" x=\"232.103\" y=\"-140\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"194,-132.359 270.205,-132.359 \"/>\n",
       "<text text-anchor=\"middle\" x=\"232.103\" y=\"-120.359\" font-family=\"Menlo\" font-size=\"10.00\">*</text>\n",
       "</g>\n",
       "<!-- 140621852371648backward&#45;&gt;140621852371168backward -->\n",
       "<g id=\"edge4\" class=\"edge\"><title>140621852371648backward&#45;&gt;140621852371168backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M203.087,-112.553C197.303,-106.596 191.352,-100.283 185.959,-94.2188 178.431,-85.7538 170.616,-76.3673 163.436,-67.4728\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"166.027,-65.1078 157.048,-59.4832 160.559,-69.4792 166.027,-65.1078\"/>\n",
       "<text text-anchor=\"middle\" x=\"207.174\" y=\"-83.0187\" font-family=\"Menlo\" font-size=\"14.00\">&#45;8.00</text>\n",
       "</g>\n",
       "<!-- 140621852371312backward -->\n",
       "<g id=\"node11\" class=\"node\"><title>140621852371312backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"232.103\" cy=\"-29.9609\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"232.103\" y=\"-27.7812\" font-family=\"Menlo\" font-size=\"10.00\">x2=2.00</text>\n",
       "</g>\n",
       "<!-- 140621852371648backward&#45;&gt;140621852371312backward -->\n",
       "<g id=\"edge7\" class=\"edge\"><title>140621852371648backward&#45;&gt;140621852371312backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M232.103,-112.509C232.103,-92.0979 232.103,-65.346 232.103,-48.1524\"/>\n",
       "</g>\n",
       "<!-- 140621852371696backward -->\n",
       "<g id=\"node7\" class=\"node\"><title>140621852371696backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M200,-224.938C200,-224.938 252.205,-224.938 252.205,-224.938 258.205,-224.938 264.205,-230.938 264.205,-236.938 264.205,-236.938 264.205,-271.859 264.205,-271.859 264.205,-277.859 258.205,-283.859 252.205,-283.859 252.205,-283.859 200,-283.859 200,-283.859 194,-283.859 188,-277.859 188,-271.859 188,-271.859 188,-236.938 188,-236.938 188,-230.938 194,-224.938 200,-224.938\"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-271.859\" font-family=\"Menlo\" font-size=\"10.00\">grad=&#45;8.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"188,-264.219 264.205,-264.219 \"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">value=0.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"188,-244.578 264.205,-244.578 \"/>\n",
       "<text text-anchor=\"middle\" x=\"226.103\" y=\"-232.578\" font-family=\"Menlo\" font-size=\"10.00\">+</text>\n",
       "</g>\n",
       "<!-- 140621852371696backward&#45;&gt;140621852371648backward -->\n",
       "<g id=\"edge2\" class=\"edge\"><title>140621852371696backward&#45;&gt;140621852371648backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M227.663,-224.727C228.372,-211.716 229.22,-196.127 229.985,-182.08\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"233.501,-181.883 230.55,-171.707 226.511,-181.502 233.501,-181.883\"/>\n",
       "<text text-anchor=\"middle\" x=\"251.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371696backward&#45;&gt;140621852371264backward -->\n",
       "<g id=\"edge9\" class=\"edge\"><title>140621852371696backward&#45;&gt;140621852371264backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M198.964,-224.726C193.556,-218.773 187.995,-212.473 182.959,-206.438 175.939,-198.024 168.659,-188.729 161.96,-179.914\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"164.552,-177.538 155.742,-171.654 158.96,-181.748 164.552,-177.538\"/>\n",
       "<text text-anchor=\"middle\" x=\"204.174\" y=\"-195.238\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371744backward -->\n",
       "<g id=\"node8\" class=\"node\"><title>140621852371744backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M202,-337.156C202,-337.156 254.205,-337.156 254.205,-337.156 260.205,-337.156 266.205,-343.156 266.205,-349.156 266.205,-349.156 266.205,-384.078 266.205,-384.078 266.205,-390.078 260.205,-396.078 254.205,-396.078 254.205,-396.078 202,-396.078 202,-396.078 196,-396.078 190,-390.078 190,-384.078 190,-384.078 190,-349.156 190,-349.156 190,-343.156 196,-337.156 202,-337.156\"/>\n",
       "<text text-anchor=\"middle\" x=\"228.103\" y=\"-384.078\" font-family=\"Menlo\" font-size=\"10.00\">grad=8.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"190,-376.438 266.205,-376.438 \"/>\n",
       "<text text-anchor=\"middle\" x=\"228.103\" y=\"-364.438\" font-family=\"Menlo\" font-size=\"10.00\">value=4.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"190,-356.797 266.205,-356.797 \"/>\n",
       "<text text-anchor=\"middle\" x=\"228.103\" y=\"-344.797\" font-family=\"Menlo\" font-size=\"10.00\">&#45;</text>\n",
       "</g>\n",
       "<!-- 140621852371744backward&#45;&gt;140621852371696backward -->\n",
       "<g id=\"edge1\" class=\"edge\"><title>140621852371744backward&#45;&gt;140621852371696backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M227.582,-336.946C227.346,-323.934 227.063,-308.345 226.808,-294.299\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"230.301,-293.861 226.62,-283.926 223.302,-293.988 230.301,-293.861\"/>\n",
       "<text text-anchor=\"middle\" x=\"249.174\" y=\"-307.456\" font-family=\"Menlo\" font-size=\"14.00\">&#45;4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371408backward -->\n",
       "<g id=\"node13\" class=\"node\"><title>140621852371408backward</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"317.103\" cy=\"-254.398\" rx=\"34.8795\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"317.103\" y=\"-252.219\" font-family=\"Menlo\" font-size=\"10.00\">y2=4.00</text>\n",
       "</g>\n",
       "<!-- 140621852371744backward&#45;&gt;140621852371408backward -->\n",
       "<g id=\"edge8\" class=\"edge\"><title>140621852371744backward&#45;&gt;140621852371408backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" stroke-dasharray=\"5,2\" d=\"M257.651,-337.059C263.309,-331.15 269.042,-324.838 274.103,-318.656 286.468,-303.552 298.811,-284.936 307.08,-271.822\"/>\n",
       "</g>\n",
       "<!-- 140621852371792backward -->\n",
       "<g id=\"node10\" class=\"node\"><title>140621852371792backward</title>\n",
       "<path fill=\"#f0f0f0\" stroke=\"black\" d=\"M159,-449.375C159,-449.375 211.205,-449.375 211.205,-449.375 217.205,-449.375 223.205,-455.375 223.205,-461.375 223.205,-461.375 223.205,-496.297 223.205,-496.297 223.205,-502.297 217.205,-508.297 211.205,-508.297 211.205,-508.297 159,-508.297 159,-508.297 153,-508.297 147,-502.297 147,-496.297 147,-496.297 147,-461.375 147,-461.375 147,-455.375 153,-449.375 159,-449.375\"/>\n",
       "<text text-anchor=\"middle\" x=\"185.103\" y=\"-496.297\" font-family=\"Menlo\" font-size=\"10.00\">grad=2.00</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"147,-488.656 223.205,-488.656 \"/>\n",
       "<text text-anchor=\"middle\" x=\"185.103\" y=\"-476.656\" font-family=\"Menlo\" font-size=\"10.00\">value=8.50</text>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"147,-469.016 223.205,-469.016 \"/>\n",
       "<text text-anchor=\"middle\" x=\"185.103\" y=\"-457.016\" font-family=\"Menlo\" font-size=\"10.00\">mse</text>\n",
       "</g>\n",
       "<!-- 140621852371792backward&#45;&gt;140621852371600backward -->\n",
       "<g id=\"edge13\" class=\"edge\"><title>140621852371792backward&#45;&gt;140621852371600backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M171.836,-449.165C165.638,-435.77 158.177,-419.645 151.531,-405.282\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"154.679,-403.75 147.303,-396.145 148.326,-406.69 154.679,-403.75\"/>\n",
       "<text text-anchor=\"middle\" x=\"179.96\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">1.00</text>\n",
       "</g>\n",
       "<!-- 140621852371792backward&#45;&gt;140621852371744backward -->\n",
       "<g id=\"edge14\" class=\"edge\"><title>140621852371792backward&#45;&gt;140621852371744backward</title>\n",
       "<path fill=\"none\" stroke=\"deepskyblue\" d=\"M196.288,-449.165C201.464,-435.898 207.685,-419.952 213.248,-405.693\"/>\n",
       "<polygon fill=\"deepskyblue\" stroke=\"deepskyblue\" points=\"216.599,-406.733 216.973,-396.145 210.078,-404.189 216.599,-406.733\"/>\n",
       "<text text-anchor=\"middle\" x=\"226.96\" y=\"-419.675\" font-family=\"Menlo\" font-size=\"14.00\">4.00</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x7fe5138ee250>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 第二次触发反向传播：梯度会在第一次的结果上累加（数值翻倍），而不是被覆盖\n",
    "loss.backward()\n",
    "draw_graph(loss, 'backward')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# 固定随机种子，使得运行结果可以稳定复现\n",
    "torch.manual_seed(1024)\n",
    "# 产生训练用的数据\n",
    "x_origin = torch.linspace(100, 300, 200)\n",
    "# 将自变量x_origin标准化得到x，否则梯度下降法很容易不稳定\n",
    "x = (x_origin - torch.mean(x_origin)) / torch.std(x_origin)\n",
    "epsilon = torch.randn(x.shape)\n",
    "y = 10 * x + 5 + epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1, Result: y = 3.12 * x + -1.99\n",
      "Step 2, Result: y = 3.48 * x + -2.28\n",
      "Step 3, Result: y = 3.22 * x + -1.97\n",
      "Step 4, Result: y = 2.85 * x + -1.22\n",
      "Step 5, Result: y = 2.68 * x + -0.23\n",
      "Step 6, Result: y = 2.92 * x + 1.08\n",
      "Step 7, Result: y = 3.74 * x + 2.61\n",
      "Step 8, Result: y = 5.07 * x + 4.15\n",
      "Step 9, Result: y = 6.73 * x + 5.52\n",
      "Step 10, Result: y = 8.22 * x + 6.48\n",
      "Step 11, Result: y = 9.36 * x + 5.75\n",
      "Step 12, Result: y = 9.75 * x + 5.42\n",
      "Step 13, Result: y = 9.88 * x + 5.28\n",
      "Step 14, Result: y = 9.89 * x + 5.26\n",
      "Step 15, Result: y = 9.89 * x + 5.20\n",
      "Step 16, Result: y = 9.88 * x + 5.18\n",
      "Step 17, Result: y = 9.88 * x + 5.17\n",
      "Step 18, Result: y = 9.84 * x + 5.14\n",
      "Step 19, Result: y = 9.86 * x + 5.15\n",
      "Step 20, Result: y = 9.94 * x + 5.21\n"
     ]
    }
   ],
   "source": [
    "# 生成模型\n",
    "model = Linear()\n",
    "# 定义每批次用到的数据量\n",
    "batch_size = 20\n",
    "learning_rate = 0.1\n",
    "\n",
    "for t in range(20):\n",
    "    # 选取当前批次的数据，用于训练模型\n",
    "    ix = (t * batch_size) % len(x)\n",
    "    xx = x[ix: ix + batch_size]\n",
    "    yy = y[ix: ix + batch_size]\n",
    "    # 计算当前批次数据的损失\n",
    "    loss = mse([model.error(_x, _y) for _x, _y in zip(xx, yy)])\n",
    "    # 计算损失函数的梯度\n",
    "    loss.backward()\n",
    "    # 迭代更新模型参数的估计值\n",
    "    model.a -= learning_rate * model.a.grad\n",
    "    model.b -= learning_rate * model.b.grad\n",
    "    # 将使用完的梯度清零\n",
    "    model.a.grad = 0.0\n",
    "    model.b.grad = 0.0\n",
    "    print(f'Step {t + 1}, Result: {model.string()}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
